"""
Import needed packages
"""
from RFS_settings import *
import RFS_Functions.Utilis as Util
import RFS_Functions.Structural as Str


"""
Load job-postings data samples
"""
# Values used, one per loop iteration below, as the fixed 'delta' parameter.
deltas = [0.01, 0.03, 0.1]
# Estimation sample, read from the pickled DataFrame 'df_est.pkl' in out_path.
Select = pd.read_pickle(os.path.join(out_path, "df_est.pkl"))


"""
Structural Estimation
"""
for dd in deltas:
    # Announce the configuration of the run for this value of delta.
    banner = '###################################################################################################'
    n_obs = len(Select)
    n_employers = len(Select.Employer_ID.unique())
    print(banner)
    print('Starting analysis:')
    print(' - Sample size:', n_obs)
    print(' - Unique employers:', n_employers)
    print(' - Initialization: from zero')
    print(' - Delta: ', dd)
    print(' - Fixed Effects: firm and time')
    print(banner)
    if dd == 0.03:
        # Figure 4 is drawn only on the delta = 0.03 pass.
        fig, ax = plt.subplots()
        # NOTE: this overwrites Select['yearmonth'] in place, so the column is
        # datetime-typed from this iteration onwards.
        Select['yearmonth'] = pd.to_datetime(Select['yearmonth'])
        wage_columns = ['w_AI', 'w_OT', 'w_DM']
        wage_colors = [plots_colors['AI'], plots_colors['OldTech'], plots_colors['DataMgmt']]
        Select.plot(x='yearmonth', y=wage_columns, color=wage_colors,
                    ax=ax, figsize=(8, 5))
        # One font size is shared by the title, ticks, axis labels and legend.
        font = {'fontsize': 18}
        plt.title('Salary variables utilized in this estimation', **font)
        plt.xticks(rotation=0, **font)
        plt.yticks(**font)
        plt.xlabel('Date', **font)
        plt.ylabel('$', **font)
        plt.legend(**font)
        # One major tick per year on the x-axis, labelled with the 4-digit year
        ax.xaxis.set_major_locator(mdates.YearLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
        figure_4_file = os.path.join(plot_path, 'Figure_4.jpg')
        plt.savefig(figure_4_file, bbox_inches='tight')
        plt.show()
    # Parameters of the least-squares problem and their bounds.
    params = lmfit.Parameters()
    # delta is held fixed (vary=False) at this iteration's value.
    params.add('delta', value=dd, min=0.0001, max=0.99, vary=False)
    # Free parameters: (name, starting value, lower bound, upper bound, brute_step)
    free_param_specs = (
        ('alpha', 0.5, 0.01, 0.99, 0.1),
        ('gamma', 0.5, 0.01, 0.99, 0.1),
        ('phi', 0.5, 0.01, 0.99, 0.1),
        ('d_0_av', 100, 0.001, 10000, 0.01),
    )
    for p_name, p_start, p_lower, p_upper, p_step in free_param_specs:
        params.add(p_name, value=p_start, min=p_lower, max=p_upper, brute_step=p_step)
    # Unpack the sample into the separate arrays the Structural class takes
    (r, w_ai, l_ai, w_ot, l_ot, w_dm, l_dm, yearmonth, emp) = Util.df_unpack(Select)
    # Add Finance risk premium to risk-free rate
    # (0.038 / 12: presumably an annual premium expressed per month -- confirm)
    finance_premium = 0.038 / 12
    r = r + finance_premium
    # Number of distinct employers with non-zero L_AI / L_OldTech in the sample
    n_ai_employers = Select.loc[Select.L_AI != 0, 'Employer_ID'].nunique()
    n_ot_employers = Select.loc[Select.L_OldTech != 0, 'Employer_ID'].nunique()
    # Create instance of Structural estimation class
    StructuralObj = Str.Structural(
        r, w_ai, l_ai, w_ot, l_ot, w_dm, l_dm, yearmonth, emp,
        FE='both', FE_tol=1e-10,
        AI_emp=n_ai_employers,
        OT_emp=n_ot_employers,
    )
    # Solve the non-linear least-squares problem for the free parameters.
    # tic/toc bracket the fit and the collection of the point estimates.
    tic = time.time()
    fitter = lmfit.Minimizer(StructuralObj.model_residuals, params)
    out = fitter.minimize(method='leastsq', maxfev=10000000000000, ftol=1.e-20)
    # Positions of alpha and gamma among the varying parameters
    # (per the original note, their indices in the estimated covariance matrix)
    alpha_index = out.var_names.index('alpha')
    gamma_index = out.var_names.index('gamma')
    # Point estimates, keyed by parameter name, consumed by the diagnostics
    out1 = {'alpha': out.params['alpha'].value,
            'gamma': out.params['gamma'].value,
            'delta': out.params['delta'].value,
            'phi': out.params['phi'].value,
            'd_0_av': out.params['d_0_av'].value}
    toc = time.time()
    print('-------------------------------')
    print('Parameter    Value       Stderr')
    out_params = pd.DataFrame(columns=['variable', 'value', 'stderr'])
    for name, param in out.params.items():
        row = {'variable': [name], 'value': [param.value], 'stderr': [param.stderr]}
        temp = pd.DataFrame(row)
        # lmfit leaves stderr as None when no standard error could be
        # estimated; '{:11.5f}' cannot format None, so show NaN on screen.
        # The CSV row above keeps the raw value.
        shown_stderr = float('nan') if param.stderr is None else param.stderr
        print('{:7s} {:11.5f} {:11.5f}'.format(name, param.value, shown_stderr))
        out_params = pd.concat([out_params, temp])
    # One results table per delta, named after the delta value
    out_params.to_csv(os.path.join(plot_path, 'Table_2_delta_' + str(dd) + '.csv'))
    print("Estimation time:", (toc - tic) / 60, 'min')
    print('Parameters:')
    for key in out.params.keys():
        print(out.params[key])


    """
    Estimation diagnostics
    """
    # Compute data process using estimated parameters
    alpha, gamma, delta, phi, d_0_av = Util.get_params(out1)
    # Work on a deep copy so renaming the labor/employer columns leaves Select untouched
    labor_column_names = {'L_DataMgmt': 'l_dm', 'Employer_ID': 'emp',
                          'L_OldTech': 'l_ot', 'L_AI': 'l_ai'}
    sample1 = copy.deepcopy(Select).rename(columns=labor_column_names)
    # Model outcome at the estimates, then merged with the renamed sample
    outcome_df = Util.outcome_merge_labor(
        StructuralObj.run_outcome(alpha, gamma, phi, delta, d_0_av),
        sample1)
    # The productivity plot is produced on the delta = 0.03 pass only
    if dd == 0.03:
        Util.plot_productivity(outcome_df, plot_path, plots_colors)

    # Construction of supporting variables
    # (the 'rhs_ident_*' / 'lhs_ident_*' columns feed the identification
    # regression and plot further down)
    dm_labor = outcome_df['l_dm']
    rate_term = outcome_df['GS1M'] + delta - 1
    outcome_df['phi_part'] = ((1 - phi) * dm_labor ** (-phi)) / rate_term
    # Right-hand sides: phi_part scaled by the odds gamma/(1-gamma) and alpha/(1-alpha)
    outcome_df['rhs_ident_ot'] = (gamma / (1 - gamma)) * outcome_df['phi_part']
    outcome_df['rhs_ident_ai'] = (alpha / (1 - alpha)) * outcome_df['phi_part']
    outcome_df['lambda_phi'] = dm_labor ** phi
    outcome_df['lhs_part'] = outcome_df['w_DM'] * outcome_df['lambda_phi']
    # Wage times labor for OT and AI
    ot_pay = outcome_df['w_OT'] * outcome_df['l_ot']
    ai_pay = outcome_df['w_AI'] * outcome_df['l_ai']
    # Left-hand sides: reciprocal of lhs_part relative to the OT, AI and AI+OT products
    outcome_df['lhs_ident_ot'] = 1 / (outcome_df['lhs_part'] / ot_pay)
    outcome_df['lhs_ident_ai'] = 1 / (outcome_df['lhs_part'] / ai_pay)
    outcome_df['lhs_ident_ai_ot'] = 1 / (outcome_df['lhs_part'] / (ai_pay + ot_pay))

    # Value of data: run the value function iteration and time it
    tic = time.time()
    err, V1 = StructuralObj.value_fun_iteration(outcome_df, delta, phi)
    toc = time.time()
    vfi_seconds = toc - tic
    print('Value function iteration completed in', vfi_seconds, 's')
    # Evaluate the returned callable V1 at every observation's 'd_it_rec'
    outcome_df['data_value'] = V1(outcome_df['d_it_rec'])

    # Representing cumulative data value (delta = 0.03 pass only)
    if dd == 0.03:
        last_data = Util.plot_value_data(outcome_df, plot_path, plots_colors)

    # Check stability in marginal value of data assumption.
    # MV is the relative gain in value from one extra unit of data:
    #     MV = (V(d + 1) - V(d)) / V(d)
    # The difference must be parenthesised before dividing; without the
    # parentheses the expression collapses to V(d + 1) - 1.
    value_now = V1(outcome_df['d_it_rec'])
    value_next = V1(outcome_df['d_it_rec'] + 1)
    outcome_df['MV'] = (value_next - value_now) / value_now
    # Sort within employer by date so shift(1) picks up the previous period
    outcome_df = outcome_df.sort_values(['emp', 'yearmonth'])
    outcome_df['MV_lag1'] = outcome_df.groupby('emp')['MV'].shift(1)
    # Period-over-period relative change in MV for each employer
    outcome_df['MVprime'] = (outcome_df['MV'] - outcome_df['MV_lag1']) / outcome_df['MV_lag1']
    print('Average MVprime: ', round((outcome_df['MVprime'].mean()) * 100, 2), '%')

    # Identification plots construction: split observations by which labor
    # types are non-zero (data-management labor must be non-zero in all three)
    has_ai = outcome_df.l_ai != 0
    no_ai = outcome_df.l_ai == 0
    has_ot = outcome_df.l_ot != 0
    no_ot = outcome_df.l_ot == 0
    has_dm = outcome_df.l_dm != 0
    OT_df = outcome_df[no_ai & has_ot & has_dm]
    AI_df = outcome_df[has_ai & has_dm & no_ot]
    AI_OT_df = outcome_df[has_ai & has_dm & has_ot]

    # Prediction for firms that have both AI and OT
    print('AI-OT sample')
    # Regressors: the OT and AI left-hand-side terms; response: 'd_it'
    X = AI_OT_df[['lhs_ident_ot', 'lhs_ident_ai']].values.reshape(-1, 2)
    Y = AI_OT_df['d_it']
    x, y = X[:, 0], X[:, 1]
    z = Y
    # Regular grid on which the fitted plane is evaluated for plotting
    x_pred = np.linspace(0, 80, 160)
    y_pred = np.linspace(0, 60, 120)
    xx_pred, yy_pred = np.meshgrid(x_pred, y_pred)
    model_viz = np.array([xx_pred.flatten(), yy_pred.flatten()]).T
    # No-intercept linear fit with scikit-learn (coefficients and grid predictions)
    ols = linear_model.LinearRegression(fit_intercept=False)
    model_ai_ot = ols.fit(X, Y)
    predicted = model_ai_ot.predict(model_viz)
    r2 = model_ai_ot.score(X, Y)
    # Same regression through statsmodels for the printed summary table
    ols_ai_ot = sm.OLS(Y, X)
    ols_result_ai_ot = ols_ai_ot.fit()
    print(ols_result_ai_ot.summary())
    print('----------', end='\n\n')

    if dd == 0.03:
        # Figure 5 (delta = 0.03 pass only): 3-D view of the AI+OT
        # observations together with the fitted no-intercept plane.
        plt.style.use('default')
        fig = plt.figure(figsize=(7.5, 7))
        ax3 = fig.add_subplot(111, projection='3d')
        axes = [ax3]
        for ax in axes:
            # Fitted plane evaluated on the prediction grid (hollow cyan markers)
            ax.scatter(xx_pred.flatten(), yy_pred.flatten(), predicted, facecolor=(0, 0, 0, 0),
                       s=20, edgecolor='cyan')
            # Observed points
            ax.plot(x, y, z, color='purple', zorder=15, linestyle='none', marker='o',
                    alpha=0.3, label='AI + OT firms')
            ax.set_xlabel('OT/DM total payments', fontsize=12)
            ax.set_ylabel('AI/DM total payments', fontsize=12)
            ax.set_zlabel('Data Stock', fontsize=12)
            # Tick density per axis. The second call previously targeted 'x'
            # again, overriding the first and leaving the y axis unconfigured.
            ax.locator_params(nbins=4, axis='x')
            ax.locator_params(nbins=5, axis='y')
            # Hand-placed arrow-head markers that accompany the slope labels below
            ax.scatter(55, 45, 3500, marker='>', s=200, color='m')
            ax.scatter(60, 61.5, 8000, marker='^', s=200, color='m')
        ax3.view_init(elev=-0, azim=165)
        # Slope annotations: coef_[0] belongs to the OT regressor, coef_[1] to AI
        fig.text(0.32, 0.38, 'AI/DM \n slope = ' + str(round(model_ai_ot.coef_[1], 2)),
                 color='m', fontsize=12)
        fig.text(0.19, 0.52, 'OT/DM \n slope\n= ' + str(round(model_ai_ot.coef_[0], 2)),
                 color='m', fontsize=12)
        fig.tight_layout()
        plt.legend(loc=(0.1, 0.8))
        fig.suptitle('Sensitivity of data stock to \n Analysis-to-Data-Management total payments',
                     fontsize=16)
        plt.savefig(os.path.join(plot_path, 'Figure_5.jpg'), bbox_inches='tight')
        plt.show()

    # Average difference in AI and old tech salary, using the first row
    # found for each yearmonth
    sal = outcome_df.drop_duplicates('yearmonth')
    avg_wage_gap = (sal.w_AI - sal.w_OT).mean()
    print('Difference in AI and OT average salary: ', avg_wage_gap)
